//===- flash_attn_fwd_test.cc -------------------------------*--- C++-*-===//
//
// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===----------------------------------------------------------------------===//

#include "brt/backends/cuda/device/common/cuda_call.h"
#include "brt/backends/cuda/device/cuda_allocator.h"
#include "brt/backends/cuda/providers/default/cuda_provider.h"
#include "brt/core/session/request_context.h"
#include "brt/core/session/session.h"
#include "brt/test/common/cuda/util.h"
#include "gtest/gtest.h"
#include <cmath>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

// Byre module loaded by SM80CUDATestFlashAttnFwd.
static std::string
    test_file_flash_attn_fwd("test/test_files/flash_attn_fwd.mlir");

// Forward-test data files: the q/k/v inputs and the expected output. All of
// them are produced by running generate_flash_attn_ground_truth.py inside
// test/test_files/.
static std::string input_q_file("test/test_files/flash_attn_inputs_q.data");
static std::string input_k_file("test/test_files/flash_attn_inputs_k.data");
static std::string input_v_file("test/test_files/flash_attn_inputs_v.data");
static std::string
    ground_truth_file("test/test_files/flash_attn_fwd_outputs.data");

using namespace brt;
using namespace brt::cuda;
using namespace brt::test;

// End-to-end check of the flash attention forward byre module.
//
// Loads test_file_flash_attn_fwd into a Session backed by the CUDA allocator
// and the default CUDA execution provider, binds device buffers, runs a single
// request and compares the fp16 output `o` element-wise with
// ground_truth_file. Inputs and ground truth come from
// generate_flash_attn_ground_truth.py (in test/test_files/).
//
// Bound arguments (BindArg indices used below):
//   0: q, 1: k, 2: v  -- fp16, b * seq_len * num_heads * head_dims elements
//   3: o              -- fp16 output, same element count as q
//   4: softmax_lse    -- fp32, b * seq_len * num_heads elements (not checked)
// NOTE(review): the element layout is presumably
// (b, seq_len, num_heads, head_dims); confirm against flash_attn_fwd.mlir.
TEST(SM80CUDATestFlashAttnFwd, Basic) {
  const size_t b = 1;
  const size_t seq_len = 128;
  const size_t num_heads = 3;
  const size_t head_dims = 32;
  const size_t input_len = b * seq_len * num_heads * head_dims;
  const size_t softmax_len = b * seq_len * num_heads;
  // size_t rng_state_len = 2;

  // Copies `len` halves from device pointer `d_ptr` to the host (through
  // CheckCUDABuffer) and compares them with the whitespace-separated float
  // values stored in `path`, each rounded to fp16 before the comparison.
  auto expect_half_buffer_matches_file = [](__half *d_ptr, size_t len,
                                            const std::string &path) {
    CheckCUDABuffer<__half>(d_ptr, len, [&](__half *h_ptr) {
      // Relative tolerance. Distinct fp16 values differ by a relative amount
      // of at least ~4.9e-4, so this effectively demands an exact match with
      // the fp16-rounded ground truth.
      constexpr float kRelTol = 2e-6f;

      std::ifstream in_file(path);
      ASSERT_TRUE(in_file.is_open())
          << "cannot open ground truth file: " << path;
      // std::vector releases the buffer even when an ASSERT below returns
      // early from this lambda.
      std::vector<__half> ground_truth(len);
      for (size_t i = 0; i < len; ++i) {
        float value = 0.f;
        ASSERT_TRUE(static_cast<bool>(in_file >> value))
            << "could not read value " << i << " of " << len
            << " from ground truth file " << path;
        ground_truth[i] = static_cast<__half>(value);
      }

      float max_diff = 0.f;
      for (size_t i = 0; i < len; ++i) {
        const float actual = static_cast<float>(h_ptr[i]);
        const float expected = static_cast<float>(ground_truth[i]);
        const float abs_diff = std::fabs(actual - expected);
        // Exact matches pass; skipping them also avoids 0/0 when expected is 0.
        if (abs_diff == 0.f)
          continue;
        // Infinite when expected == 0, which correctly counts as a mismatch.
        const float ratio = abs_diff / std::fabs(expected);
        if (ratio > max_diff)
          max_diff = ratio;
        // EXPECT_LE also fails on NaN (e.g. a NaN kernel output), which a
        // plain `ratio > kRelTol` test would let through.
        EXPECT_LE(ratio, kRelTol) << "index " << i << ": actual " << actual
                                  << ", expected " << expected;
      }
      std::cout << "max_diff (ratio):" << max_diff << std::endl;
    });
  };

  Session session;
  auto status_allocator = CUDAAllocatorFactory(&session);
  BRT_TEST_CHECK_STATUS(status_allocator);
  auto status_cuda = DefaultCUDAExecutionProviderFactory(&session);
  BRT_TEST_CHECK_STATUS(status_cuda);

  auto status_load = session.Load(test_file_flash_attn_fwd, "byre");
  BRT_TEST_CHECK_STATUS(status_load);

  std::unique_ptr<RequestContext> request;
  auto status_request = session.NewRequestContext(&request);
  BRT_TEST_CHECK_STATUS(status_request);

  __half *d_o = nullptr;
  __half *d_q = nullptr;
  __half *d_k = nullptr;
  __half *d_v = nullptr;
  float *d_softmax_lse = nullptr;

  // rng_state (currently not bound)
  // uint64_t *d_rng_state;
  // uint64_t h_rng_state[2];
  // h_rng_state[0] = 0UL;
  // h_rng_state[1] = 3000UL;

  ASSERT_EQ(cudaSuccess, cudaMalloc(&d_o, input_len * sizeof(__half)));
  ASSERT_EQ(cudaSuccess, cudaMalloc(&d_q, input_len * sizeof(__half)));
  ASSERT_EQ(cudaSuccess, cudaMalloc(&d_k, input_len * sizeof(__half)));
  ASSERT_EQ(cudaSuccess, cudaMalloc(&d_v, input_len * sizeof(__half)));
  ASSERT_EQ(cudaSuccess,
            cudaMalloc(&d_softmax_lse, softmax_len * sizeof(float)));

  // cudaMalloc(&d_rng_state, rng_state_len * sizeof(uint64_t));
  // cudaMemcpy(d_rng_state, h_rng_state, rng_state_len * sizeof(uint64_t),
  // cudaMemcpyHostToDevice);

  ReadCUDAFloatValues(d_q, input_len, input_q_file);
  ReadCUDAFloatValues(d_k, input_len, input_k_file);
  ReadCUDAFloatValues(d_v, input_len, input_v_file);
  AssignCUDABuffer(d_softmax_lse, softmax_len, 0.f);
  AssignCUDABuffer(d_o, input_len, static_cast<__half>(0.f));

  // Surface any error from the setup helpers before running the module.
  ASSERT_EQ(cudaSuccess, cudaDeviceSynchronize());

  // PrintCUDAValues(d_o, input_len, input_len);
  // PrintCUDAValues(d_q, input_len, input_len);
  // PrintCUDAValues(d_k, input_len, input_len);
  // PrintCUDAValues(d_v, input_len, input_len);
  // PrintCUDAValues(d_softmax_lse, softmax_len, 10);

  request->BindArg(0, d_q);
  request->BindArg(1, d_k);
  request->BindArg(2, d_v);
  request->BindArg(3, d_o);
  request->BindArg(4, d_softmax_lse);
  // request->BindArg(6, d_rng_state);

  request->FinishIOBinding();

  auto status_run = session.Run(*request);
  BRT_TEST_CHECK_STATUS(status_run);
  auto status_sync = request->Sync();
  BRT_TEST_CHECK_STATUS(status_sync);

  // PrintCUDAValues(d_o, input_len, input_len);

  expect_half_buffer_matches_file(d_o, input_len, ground_truth_file);

  cudaFree(d_o);
  cudaFree(d_q);
  cudaFree(d_k);
  cudaFree(d_v);
  cudaFree(d_softmax_lse);
}

// Byre module loaded by SM80CUDATestFlashAttnKVCache.
static std::string
    test_file_flash_attn_kvcache("test/test_files/flash_attn_kvcache.mlir");

// KV-cache test data: the inputs (q, new k/v, the k/v caches and the cache
// sequence lengths) followed by the expected output and the expected caches
// after the run. Everything is generated by running
// generate_flash_attn_ground_truth.py inside test/test_files/.
static std::string
    kvcache_input_q_file("test/test_files/flash_attn_kvcache_inputs_q.data");
static std::string
    kvcache_input_k_file("test/test_files/flash_attn_kvcache_inputs_k.data");
static std::string
    kvcache_input_v_file("test/test_files/flash_attn_kvcache_inputs_v.data");
static std::string kvcache_input_kcache_file(
    "test/test_files/flash_attn_kvcache_inputs_kcache.data");
static std::string kvcache_input_vcache_file(
    "test/test_files/flash_attn_kvcache_inputs_vcache.data");
static std::string kvcache_input_cache_seqlens_file(
    "test/test_files/flash_attn_kvcache_inputs_cache_seqlens.data");
static std::string kvcache_ground_truth_file(
    "test/test_files/flash_attn_kvcache_outputs.data");
static std::string kvcache_ground_truth_kcache_file(
    "test/test_files/flash_attn_kvcache_outputs_kcache.data");
static std::string kvcache_ground_truth_vcache_file(
    "test/test_files/flash_attn_kvcache_outputs_vcache.data");

// End-to-end check of the flash attention KV-cache byre module.
//
// Binds a query with seq_len_q = 1, new k/v of the same size, the k/v caches
// and b int32 cache sequence lengths, runs a single request, then compares the
// fp16 output `o` and both caches (which the run updates in place) with ground
// truth from generate_flash_attn_ground_truth.py (in test/test_files/).
//
// Bound arguments (BindArg indices used below):
//   0: q, 3: k, 4: v      -- fp16, b * seq_len_q * num_heads * head_dims
//   1: kcache, 2: vcache  -- fp16, b * seq_len * num_heads * head_dims
//   5: cache_seqlens      -- int32, b elements
//   6: o                  -- fp16 output, same element count as q
//   7: softmax_lse        -- fp32, b * seq_len_q * num_heads (not checked)
// NOTE(review): element layouts are presumably batch-major as the size
// products suggest; confirm against flash_attn_kvcache.mlir.
TEST(SM80CUDATestFlashAttnKVCache, Basic) {
  const size_t b = 2;
  const size_t seq_len = 128;
  const size_t seq_len_q = 1;
  const size_t num_heads = 3;
  const size_t head_dims = 32;
  const size_t input_len = b * seq_len_q * num_heads * head_dims;
  const size_t softmax_len = b * seq_len_q * num_heads;
  const size_t cache_len = b * seq_len * num_heads * head_dims;

  // Copies `len` halves from device pointer `d_ptr` to the host (through
  // CheckCUDABuffer) and compares them with the whitespace-separated float
  // values stored in `path`, each rounded to fp16 before the comparison.
  auto expect_half_buffer_matches_file = [](__half *d_ptr, size_t len,
                                            const std::string &path) {
    CheckCUDABuffer<__half>(d_ptr, len, [&](__half *h_ptr) {
      // Relative tolerance. Distinct fp16 values differ by a relative amount
      // of at least ~4.9e-4, so this effectively demands an exact match with
      // the fp16-rounded ground truth.
      constexpr float kRelTol = 2e-6f;

      std::ifstream in_file(path);
      // Report the real file path: the same checker is used for the output
      // and for both caches.
      ASSERT_TRUE(in_file.is_open())
          << "cannot open ground truth file: " << path;
      // std::vector releases the buffer even when an ASSERT below returns
      // early from this lambda.
      std::vector<__half> ground_truth(len);
      for (size_t i = 0; i < len; ++i) {
        float value = 0.f;
        ASSERT_TRUE(static_cast<bool>(in_file >> value))
            << "could not read value " << i << " of " << len
            << " from ground truth file " << path;
        ground_truth[i] = static_cast<__half>(value);
      }

      float max_diff = 0.f;
      for (size_t i = 0; i < len; ++i) {
        const float actual = static_cast<float>(h_ptr[i]);
        const float expected = static_cast<float>(ground_truth[i]);
        const float abs_diff = std::fabs(actual - expected);
        // Exact matches pass; skipping them also avoids 0/0 when expected is 0.
        if (abs_diff == 0.f)
          continue;
        // Infinite when expected == 0, which correctly counts as a mismatch.
        const float ratio = abs_diff / std::fabs(expected);
        if (ratio > max_diff)
          max_diff = ratio;
        // EXPECT_LE also fails on NaN (e.g. a NaN kernel output), which a
        // plain `ratio > kRelTol` test would let through.
        EXPECT_LE(ratio, kRelTol) << "index " << i << ": actual " << actual
                                  << ", expected " << expected;
      }
      std::cout << "max_diff (ratio):" << max_diff << std::endl;
    });
  };

  Session session;
  auto status_allocator = CUDAAllocatorFactory(&session);
  BRT_TEST_CHECK_STATUS(status_allocator);
  auto status_cuda = DefaultCUDAExecutionProviderFactory(&session);
  BRT_TEST_CHECK_STATUS(status_cuda);

  auto status_load = session.Load(test_file_flash_attn_kvcache, "byre");
  BRT_TEST_CHECK_STATUS(status_load);

  std::unique_ptr<RequestContext> request;
  auto status_request = session.NewRequestContext(&request);
  BRT_TEST_CHECK_STATUS(status_request);

  __half *d_o = nullptr;
  __half *d_q = nullptr;
  __half *d_k = nullptr;
  __half *d_v = nullptr;
  __half *d_kcache = nullptr;
  __half *d_vcache = nullptr;
  int32_t *d_seqlen = nullptr;
  float *d_softmax_lse = nullptr;

  ASSERT_EQ(cudaSuccess, cudaMalloc(&d_o, input_len * sizeof(__half)));
  ASSERT_EQ(cudaSuccess, cudaMalloc(&d_q, input_len * sizeof(__half)));
  ASSERT_EQ(cudaSuccess, cudaMalloc(&d_k, input_len * sizeof(__half)));
  ASSERT_EQ(cudaSuccess, cudaMalloc(&d_v, input_len * sizeof(__half)));
  ASSERT_EQ(cudaSuccess, cudaMalloc(&d_kcache, cache_len * sizeof(__half)));
  ASSERT_EQ(cudaSuccess, cudaMalloc(&d_vcache, cache_len * sizeof(__half)));
  ASSERT_EQ(cudaSuccess, cudaMalloc(&d_seqlen, b * sizeof(int32_t)));
  ASSERT_EQ(cudaSuccess,
            cudaMalloc(&d_softmax_lse, softmax_len * sizeof(float)));

  ReadCUDAFloatValues(d_q, input_len, kvcache_input_q_file);
  ReadCUDAFloatValues(d_k, input_len, kvcache_input_k_file);
  ReadCUDAFloatValues(d_v, input_len, kvcache_input_v_file);
  ReadCUDAFloatValues(d_kcache, cache_len, kvcache_input_kcache_file);
  ReadCUDAFloatValues(d_vcache, cache_len, kvcache_input_vcache_file);
  ReadCUDAIntegerValues(d_seqlen, b, kvcache_input_cache_seqlens_file);
  AssignCUDABuffer(d_softmax_lse, softmax_len, 0.f);
  AssignCUDABuffer(d_o, input_len, static_cast<__half>(0.f));

  // Surface any error from the setup helpers before running the module.
  ASSERT_EQ(cudaSuccess, cudaDeviceSynchronize());

  // PrintCUDAValues(d_o, input_len, input_len);
  // PrintCUDAValues(d_q, input_len, input_len);
  // PrintCUDAValues(d_k, input_len, input_len);
  // PrintCUDAValues(d_v, input_len, input_len);
  // PrintCUDAValues(d_softmax_lse, softmax_len, 10);

  request->BindArg(0, d_q);
  request->BindArg(1, d_kcache);
  request->BindArg(2, d_vcache);
  request->BindArg(3, d_k);
  request->BindArg(4, d_v);
  request->BindArg(5, d_seqlen);
  request->BindArg(6, d_o);
  request->BindArg(7, d_softmax_lse);

  request->FinishIOBinding();

  auto status_run = session.Run(*request);
  BRT_TEST_CHECK_STATUS(status_run);
  auto status_sync = request->Sync();
  BRT_TEST_CHECK_STATUS(status_sync);

  // PrintCUDAValues(d_o, input_len, input_len);

  expect_half_buffer_matches_file(d_o, input_len, kvcache_ground_truth_file);

  // check kvcache update: the caches are modified in place by the run.
  expect_half_buffer_matches_file(d_kcache, cache_len,
                                  kvcache_ground_truth_kcache_file);
  expect_half_buffer_matches_file(d_vcache, cache_len,
                                  kvcache_ground_truth_vcache_file);

  cudaFree(d_o);
  cudaFree(d_q);
  cudaFree(d_k);
  cudaFree(d_v);
  cudaFree(d_kcache);
  cudaFree(d_vcache);
  cudaFree(d_seqlen);
  cudaFree(d_softmax_lse);
}
